from federatedscope.core.configs.config import CN
from federatedscope.core.configs.yacs_config import Argument
from federatedscope.register import register_config


def extend_fl_algo_cfg(cfg):
    """Add the FL-algorithm option groups and their defaults to ``cfg``.

    The following sub-nodes are created on ``cfg`` and populated with
    default values: ``fedopt``, ``fedprox``, ``fedswa``, ``personalization``,
    ``fedsageplus``, ``gcflplus`` and ``flitplus``. Finally,
    ``assert_fl_algo_cfg`` is registered on ``cfg`` as a check function.

    Args:
        cfg: the config node to extend; it is modified in place.
    """
    # ---------------------------------------------------------------------- #
    # fedopt related options, a general fl algorithm
    # ---------------------------------------------------------------------- #
    cfg.fedopt = CN()

    cfg.fedopt.use = False

    # `new_allowed=True` is passed so that this node may accept additional
    # optimizer keys beyond `type` and `lr`
    cfg.fedopt.optimizer = CN(new_allowed=True)
    cfg.fedopt.optimizer.type = Argument(
        'SGD', description="optimizer type for FedOPT")
    cfg.fedopt.optimizer.lr = Argument(
        0.01, description="learning rate for FedOPT optimizer")
    cfg.fedopt.annealing = False
    cfg.fedopt.annealing_step_size = 2000
    cfg.fedopt.annealing_gamma = 0.5

    # ---------------------------------------------------------------------- #
    # fedprox related options, a general fl algorithm
    # ---------------------------------------------------------------------- #
    cfg.fedprox = CN()

    cfg.fedprox.use = False
    cfg.fedprox.mu = 0.

    # ---------------------------------------------------------------------- #
    # fedswa related options, Stochastic Weight Averaging (SWA)
    # ---------------------------------------------------------------------- #
    cfg.fedswa = CN()
    cfg.fedswa.use = False
    cfg.fedswa.freq = 10
    # When `fedswa.use` is True, `assert_fl_algo_cfg` requires `start_rnd` to
    # be smaller than `cfg.federate.total_round_num`
    cfg.fedswa.start_rnd = 30

    # ---------------------------------------------------------------------- #
    # Personalization related options, pFL
    # ---------------------------------------------------------------------- #
    cfg.personalization = CN()

    # client-distinct param names, e.g., ['pre', 'post']
    cfg.personalization.local_param = []
    cfg.personalization.share_non_trainable_para = False
    # -1 is a placeholder: `assert_fl_algo_cfg` replaces it with
    # `cfg.train.local_update_steps`
    cfg.personalization.local_update_steps = -1
    # @regular_weight:
    # The smaller the regular_weight is, the stronger the emphasis on the
    # personalized model
    # For Ditto, the default value=0.1, the search space is [0.05, 0.1, 0.2,
    # 1, 2]
    # For pFedMe, the default value=15
    cfg.personalization.regular_weight = 0.1

    # @lr:
    # 1) For pFedMe, the personalized learning rate used to calculate theta
    # approximately using K steps
    # 2) 0.0 (or any non-positive value) means the user has not specified a
    # valid lr; `assert_fl_algo_cfg` then uses `cfg.train.optimizer.lr`
    cfg.personalization.lr = 0.0

    cfg.personalization.K = 5  # the local approximation steps for pFedMe
    cfg.personalization.beta = 1.0  # the average moving parameter for pFedMe

    # parameters for FedRep:
    cfg.personalization.lr_feature = 0.1  # learning rate: feature extractors
    cfg.personalization.lr_linear = 0.1  # learning rate: linear head
    cfg.personalization.epoch_feature = 1  # training epoch number
    cfg.personalization.epoch_linear = 2  # training epoch number
    cfg.personalization.weight_decay = 0.0

    # ---------------------------------------------------------------------- #
    # FedSage+ related options, gfl
    # ---------------------------------------------------------------------- #
    cfg.fedsageplus = CN()

    # Number of nodes generated by the generator
    cfg.fedsageplus.num_pred = 5
    # Hidden layer dimension of generator
    cfg.fedsageplus.gen_hidden = 128
    # Hide graph portion
    cfg.fedsageplus.hide_portion = 0.5
    # Federated training round for generator
    cfg.fedsageplus.fedgen_epoch = 200
    # Local pre-train round for generator
    cfg.fedsageplus.loc_epoch = 1
    # Coefficient for criterion number of missing node
    cfg.fedsageplus.a = 1.0
    # Coefficient for criterion feature
    cfg.fedsageplus.b = 1.0
    # Coefficient for criterion classification
    cfg.fedsageplus.c = 1.0

    # ---------------------------------------------------------------------- #
    # GCFL+ related options, gfl
    # ---------------------------------------------------------------------- #
    cfg.gcflplus = CN()

    # Bound for mean_norm
    cfg.gcflplus.EPS_1 = 0.05
    # Bound for max_norm
    cfg.gcflplus.EPS_2 = 0.1
    # Length of the gradient sequence
    cfg.gcflplus.seq_length = 5
    # Whether to standardize dtw_distances
    cfg.gcflplus.standardize = False

    # ---------------------------------------------------------------------- #
    # FLIT+ related options, gfl
    # ---------------------------------------------------------------------- #
    cfg.flitplus = CN()

    cfg.flitplus.tmpFed = 0.5  # gamma in focal loss (Eq.4)
    cfg.flitplus.lambdavat = 0.5  # lambda in phi (Eq.10)
    cfg.flitplus.factor_ema = 0.8  # beta in omega (Eq.12)
    cfg.flitplus.weightReg = 1.0  # balance lossLocalLabel and lossLocalVAT

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_fl_algo_cfg)


def assert_fl_algo_cfg(cfg):
    """Fill in dependent defaults and validate the FL-algorithm options.

    ``cfg`` is modified in place:

    * ``cfg.personalization.local_update_steps`` equal to ``-1`` is replaced
      by ``cfg.train.local_update_steps``.
    * A non-positive ``cfg.personalization.lr`` is replaced by
      ``cfg.train.optimizer.lr``.

    Args:
        cfg: the config node holding the ``personalization``, ``train``,
            ``fedswa`` and ``federate`` option groups.

    Raises:
        AssertionError: if ``cfg.fedswa.use`` is set and
            ``cfg.fedswa.start_rnd`` is not smaller than
            ``cfg.federate.total_round_num``.
    """
    if cfg.personalization.local_update_steps == -1:
        # By default, use the same number of steps as the normal mode
        cfg.personalization.local_update_steps = \
            cfg.train.local_update_steps

    if cfg.personalization.lr <= 0.0:
        # By default, use the same lr as the normal mode
        cfg.personalization.lr = cfg.train.optimizer.lr

    if cfg.fedswa.use:
        # SWA can only start at a round that actually exists in the course
        assert cfg.fedswa.start_rnd < cfg.federate.total_round_num, \
            f'`cfg.fedswa.start_rnd` {cfg.fedswa.start_rnd} must be smaller ' \
            f'than `cfg.federate.total_round_num` ' \
            f'{cfg.federate.total_round_num}.'


# Register the extension function above under the name "fl_algo"
register_config("fl_algo", extend_fl_algo_cfg)
